
# load packages
library(tidyverse)
library(sandwich)

# seed R's random number generator before the bootstrap and sampling steps below
set.seed(58233)


# read the rescaled data and the weights, printing a quick overview of each
df <- read_rds("output/rescaled-data.rds")
glimpse(df)

wts <- read_rds("output/weights.rds")
glimpse(wts)


# no-intercept model with one level term per issue_id plus an
# issue_id:treat_indicator product term per issue; each product term is the
# ATE-hat for that policy (i.e., "topic")
fit <- lm(agree_measure ~ -1 + issue_id + issue_id:treat_indicator, data = df)
fit

# keep only the product-term coefficients and strip the term labels so that
# each ATE-hat is named by its issue alone
ate_terms <- str_detect(names(coef(fit)), ":treat_indicator")
delta_hat <- coef(fit)[ate_terms]
names(delta_hat) <- names(delta_hat) %>%
  str_remove(":treat_indicator") %>%
  str_remove("issue_id")

# bootstrap the coefficient covariance matrix (2000 replications, clustering
# on issue_id and respid), keep the ATE-hat rows/columns, and time the run
# NOTE(review): the replications are spread over 12 cores; confirm that the
# set.seed() call alone makes them reproducible under parallel execution
ate_terms <- str_detect(names(coef(fit)), ":treat_indicator")
st1 <- system.time({
  v_boot <- vcovBS(fit, cluster = ~ issue_id + respid, R = 2000, cores = 12)
  v_hat <- v_boot[ate_terms, ate_terms]
})
print(st1)


# alternative covariance estimators, kept for reference:
#v_hat <- vcovCL(fit, cluster = ~ issue_id + respid)[str_detect(names(coef(fit)), ":treat_indicator"), str_detect(names(coef(fit)), ":treat_indicator")]
#v_hat <- vcov(fit)[str_detect(names(coef(fit)), ":treat_indicator"), str_detect(names(coef(fit)), ":treat_indicator")]

# relabel rows and columns so each is named by its issue alone
strip_term_label <- function(labels) {
  str_remove(str_remove(labels, ":treat_indicator"), "issue_id")
}
rownames(v_hat) <- strip_term_label(rownames(v_hat))
colnames(v_hat) <- strip_term_label(colnames(v_hat))

# one row per distinct (policy, issue, category, issue_id) combination in the
# data, with the columns of the weights table joined on
issue_keys <- df %>%
  select(policy, issue, category, issue_id) %>%
  distinct()
w_star <- left_join(issue_keys, wts)
glimpse(w_star)

# design weights as a vector named by issue_id, to match delta_hat and v_hat
w <- setNames(w_star$design_weight, w_star$issue_id)


# compute weighted average of the ATE estimates
issues <- sort(unique(df$issue_id))

# Look the weights and estimates up by issue *name*. Indexing a vector with a
# factor uses the factor's integer codes, not its labels, so if issue_id is a
# factor `w[issues[i]]` would pick weights by position and could pair an
# estimate with another issue's weight. Converting to character forces a
# name-based lookup whatever type issue_id has.
issue_names <- as.character(issues)

# per-issue contributions w_i * delta_i (vectorised; no growing vector)
bits <- w[issue_names] * delta_hat[issue_names]

# weighted mean: sum_i w_i * delta_i / sum_i w_i
deltaG_hat <- sum(bits) / sum(w)
deltaG_hat

# combine (co)variances
# Var(sum_i w_i * delta_i) = sum_i sum_j w_i * w_j * v_ij

# Index by issue *name*: subsetting with a factor would use its integer codes
# and could pair weights with the wrong rows/columns of v_hat.
issue_names <- as.character(issues)
w_issue <- w[issue_names]

# matrix of the individual terms w_i * w_j * v_ij
bits <- outer(w_issue, w_issue) *
  v_hat[issue_names, issue_names, drop = FALSE]

# variance of the weighted *sum* of ATE-hats (not yet divided by sum(w)^2)
v_deltaG_hat <- sum(bits)

# standard error of the weighted average; it is normalised by sum(w), the same
# denominator used for the point estimate, so the two stay consistent even if
# the weights do not sum to the number of issues
sqrt(v_deltaG_hat) / sum(w)


# hierarchical model: intercept and treat_indicator slope vary by issue_id,
# plus a varying intercept by respid
library(lme4)
fit_hier <- lmer(
  agree_measure ~ treat_indicator + (1 + treat_indicator | issue_id) + (1 | respid),
  data = df
)
arm::display(fit_hier)

library(tidybayes)

# predictions use only the issue-level random effects
issue_re <- ~ (1 + treat_indicator | issue_id)

# predicted outcome for every issue with treat_indicator = 0
grid0 <- tibble(issue_id = unique(df$issue_id), treat_indicator = 0)
s0 <- grid0 %>%
  mutate(.epred0 = predict(fit_hier, newdata = grid0, re.form = issue_re)) %>%
  select(-treat_indicator)
glimpse(s0)

# predicted outcome for every issue with treat_indicator = 1
grid1 <- tibble(issue_id = unique(df$issue_id), treat_indicator = 1)
s1 <- grid1 %>%
  mutate(.epred1 = predict(fit_hier, newdata = grid1, re.form = issue_re)) %>%
  select(-treat_indicator)
glimpse(s1)

# per-issue ATE = difference between the two predictions, labelled with the
# issue's stem and the estimation method
issue_stems <- distinct(select(df, issue_id, stem))
s_lmer <- full_join(s0, s1)
s_lmer <- mutate(s_lmer, ate = .epred1 - .epred0)
s_lmer <- left_join(s_lmer, issue_stems)
s_lmer <- mutate(s_lmer, method = "lmer")
s_lmer <- select(s_lmer, -starts_with(".epred"))
glimpse(s_lmer)


# same hierarchical specification, fit with rstanarm
library(rstanarm)
options(mc.cores = parallel::detectCores())
fit_stan <- stan_lmer(
  agree_measure ~ treat_indicator + (1 + treat_indicator | issue_id) + (1 | respid),
  data = df,
  iter = 2000
)

library(tidybayes)

# draws use only the issue-level random effects; respid is filled with a
# placeholder value in the prediction grids
issue_re <- ~ (1 + treat_indicator | issue_id)

# expected-outcome draws for every issue with treat_indicator = 0
grid0 <- tibble(issue_id = unique(df$issue_id),
                treat_indicator = 0,
                respid = "placeholder")
s0 <- add_epred_draws(grid0, fit_stan, re_formula = issue_re)
s0 <- ungroup(s0)
s0 <- select(s0, -treat_indicator)
s0 <- rename(s0, .epred0 = .epred)
glimpse(s0)

# expected-outcome draws for every issue with treat_indicator = 1
grid1 <- tibble(issue_id = unique(df$issue_id),
                treat_indicator = 1,
                respid = "placeholder")
s1 <- add_epred_draws(grid1, fit_stan, re_formula = issue_re)
s1 <- ungroup(s1)
s1 <- select(s1, -treat_indicator)
s1 <- rename(s1, .epred1 = .epred)
glimpse(s1)

# per-draw ATE, averaged over draws within each issue, then labelled with the
# issue's stem and the estimation method
issue_stems <- distinct(select(df, issue_id, stem))
s_stan <- full_join(s0, s1)
s_stan <- mutate(s_stan, ate = .epred1 - .epred0)
s_stan <- group_by(s_stan, issue_id)
s_stan <- summarize(s_stan, ate = mean(ate))
s_stan <- left_join(s_stan, issue_stems)
s_stan <- mutate(s_stan, method = "Stan")
glimpse(s_stan)

# stack both sets of estimates and order the stems by ATE for plotting
s <- bind_rows(s_stan, s_lmer)
s <- mutate(s, stem = reorder(stem, ate))
glimpse(s)

# compare the per-issue ATE estimates from the two methods
ggplot(s, aes(x = ate, y = stem, color = method)) +
  geom_point()



